
[Aiter][ROCm] RMSNormGated+GroupedQuantFP8 fusion#40710

Open
tpopp wants to merge 10 commits into vllm-project:main from tpopp:tpopp/gdn-rmsnorm-quant-fusion

Conversation

@tpopp
Contributor

@tpopp tpopp commented Apr 23, 2026

This PR adds a compilation fusion pass (AiterRMSNormGatedFp8GroupQuantPattern) that fuses the decomposed RMSNormGated + reshape + group FP8 quantization sequence into a single AITER Triton kernel call (fused_rms_gated_fp8_group_quant). This pattern appears in GatedDeltaNetAttention layers (e.g., Qwen3-Next) where each attention head's output goes through gated RMS normalization, is reshaped back to the full hidden dimension, and then group-quantized to FP8 before the output projection linear layer.

Results:
A pair of kernels totaling ~9 µs is fused into a single ~4.5 µs kernel. For Qwen3-Next this translates to a 1–3% end-to-end improvement, depending on workload size (concurrency 1 vs. 128).

Motivation

In models using GatedDeltaNetAttention (such as Qwen3-Next-80B-A3B-Instruct-FP8), the output path of each attention block performs:

  1. RMSNormGated on per-head tensors (N*H, D) with a gating tensor
  2. Reshape to (N, H*D)
  3. Group FP8 quantization with GroupShape(1, 128)

These three operations decompose into many elementwise and reduction kernels when torch.compile lowers them. By matching this pattern in the FX graph and replacing it with a single fused Triton kernel from AITER, we eliminate multiple GPU kernel launches and intermediate memory traffic.
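The decomposed sequence can be sketched as a plain NumPy reference. This is an illustrative sketch, not the vLLM or AITER implementation: the gating order (gate applied before normalization) and the e4m3 maximum of 448 are assumptions, since RMSNormGated has variants and ROCm's fnuz float8 format has a different dynamic range.

```python
import numpy as np

FP8_MAX = 448.0  # max normal of float8_e4m3fn (assumption; fnuz variants differ)

def silu(x):
    return x / (1.0 + np.exp(-x))

def gated_norm_reshape_quant(x, gate, weight, num_heads, eps=1e-6, group_size=128):
    # 1. gated RMS normalization on per-head tensors of shape (N*H, D)
    h = x * silu(gate)  # gating-before-norm is an assumption; variants exist
    normed = h / np.sqrt(np.mean(h * h, axis=-1, keepdims=True) + eps) * weight
    # 2. reshape back to the full hidden dimension: (N*H, D) -> (N, H*D)
    n = x.shape[0] // num_heads
    flat = normed.reshape(n, num_heads * x.shape[1])
    # 3. group FP8 quantization with GroupShape(1, group_size): one scale
    #    per contiguous run of group_size elements along the last axis
    groups = flat.reshape(n, -1, group_size)
    scales = np.abs(groups).max(axis=-1, keepdims=True) / FP8_MAX
    scales = np.maximum(scales, 1e-12)  # guard all-zero groups
    q = np.clip(groups / scales, -FP8_MAX, FP8_MAX)  # real kernel casts to fp8
    return q.reshape(n, -1), scales.reshape(n, scales.shape[1])
```

The fused AITER kernel performs all three steps in one launch instead of materializing the normalized and reshaped intermediates.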

Changes
• Register rocm_aiter_fused_rms_gated_fp8_group_quant custom op wrapping aiter.ops.triton.quant.fused_rms_gated_fp8_group_quant
• Add rocm_aiter_ops.are_gdn_triton_kernels_available() — checks whether the required AITER Triton kernels (causal_conv1d_update_single_token, gated_delta_net) are importable, allowing graceful fallback on older AITER builds that lack the GDN kernels
• rocm_aiter_fusion.py: Add AiterRMSNormGatedFp8GroupQuantPattern that matches the decomposed norm→reshape→quant graph and replaces it with the fused op. Add _fold_consecutive_reshapes pre-processing pass (needed because make_fx faithfully records chained reshapes that must be folded for the pattern to match). Dynamically infer num_heads/head_dim from GatedDeltaNetAttention layers via static_forward_context. Gate the pattern on are_gdn_triton_kernels_available()
• matcher_utils.py: Add MatcherRMSNormGated pattern tracer that traces RMSNormGated.forward_static for use in pm.register_replacement. Extend MatcherQuantFP8 to support Triton-based quant op matching
• layernorm.py: Extract RMSNormGated.forward_static as a @staticmethod so both forward_native and the matcher can share the same pure-PyTorch implementation. forward_native delegates to it
• test_fusion.py: Add unit tests (TestGatedModel) for the fusion pattern covering positive match cases (aiter quant, non-aiter quant, per-token dynamic) and negative cases (wrong group shape, per-tensor quant)
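The reshape-folding pre-pass rests on a simple algebraic identity: in a chain of reshapes, only the last target shape matters. A minimal sketch of that identity in plain NumPy (the actual _fold_consecutive_reshapes pass rewrites FX graph nodes, which is not shown here):

```python
import numpy as np

def fold_reshape_chain(shapes):
    """Collapse a chain of reshape target shapes to the one equivalent
    reshape: reshape(reshape(x, s1), s2) == reshape(x, s2). This only
    demonstrates the identity the FX pass depends on."""
    return shapes[-1] if shapes else None

x = np.arange(24).reshape(2, 12)
chain = [(6, 4), (4, 6), (2, 12)]
stepwise = x
for s in chain:  # apply each reshape in turn
    stepwise = stepwise.reshape(s)
folded = x.reshape(fold_reshape_chain(chain))
assert np.array_equal(stepwise, folded)
```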

AITER Dependency

The fused Triton kernel (fused_rms_gated_fp8_group_quant) is provided by ROCm/aiter#2423 (https://github.com/ROCm/aiter/pull/2423) ("[Triton] optimized decode kernels for Qwen3-Next model"). The fusion pass is gated behind rocm_aiter_ops.are_gdn_triton_kernels_available(), so it is a no-op on AITER versions that do not include this PR.
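An availability probe of this kind can be sketched with importlib. This is a generic illustration in the spirit of are_gdn_triton_kernels_available(), not the vLLM implementation; the module names a caller would pass are whatever AITER kernels the pass requires.

```python
import importlib.util

def kernels_available(module_names):
    """Return True only when every required module resolves, so a
    fusion pass can degrade to a no-op on builds that predate the
    kernels. Sketch only; not vLLM code."""
    for name in module_names:
        try:
            if importlib.util.find_spec(name) is None:
                return False
        except (ImportError, ValueError):
            # missing parent package, or a module with __spec__ == None
            return False
    return True
```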

Benchmark Results

Setup:
• Model: Qwen/Qwen3-Next-80B-A3B-Instruct-FP8, TP=1
• GPU: AMD MI355x (gfx950), single GPU
• Base image: vllm/vllm-openai-rocm:nightly (vLLM v0.19.2rc1) with AITER rebuilt from aiter:main + PR #2423
• Attention backend: ROCM_AITER_FA
• Compilation: cudagraph_mode=FULL_AND_PIECEWISE, custom_ops=["-rms_norm", "-silu_and_mul", "+quant_fp8"], pass_config={"fuse_norm_quant": true}
• Benchmark command: vllm bench serve --dataset_name random --random_input_len 1024 --random_output_len 1024 --max_concurrency 4 --num_prompts 32 --num_warmups 4 --seed 1 --temperature 0 --ignore_eos

Pattern matching verification:
• With fusion: RocmAiterRMSNormQuantFusionPass replaced 5 patterns (1+2+2 across repeated-layer subgraphs — the 4 additional matches are from AiterRMSNormGatedFp8GroupQuantPattern)
• Without fusion (pattern commented out): replaced 1 pattern (only the existing non-gated AiterRMSNormDynamicQuantPattern)

Throughput (ISL=1024, OSL=1024, concurrency=4):

┌─────────────────────────────────┬─────────────┬──────────┬───────┐
│ Metric │ With Fusion │ Baseline │ Delta │
├─────────────────────────────────┼─────────────┼──────────┼───────┤
│ Output token throughput (tok/s) │ 467.05 │ 456.52 │ +2.3% │
│ Total token throughput (tok/s) │ 934.11 │ 913.04 │ +2.3% │
│ Mean TPOT (ms) │ 8.44 │ 8.66 │ −2.5% │
│ P99 TPOT (ms) │ 8.67 │ 8.98 │ −3.5% │
│ Mean E2EL (ms) │ 8,769 │ 8,971 │ −2.3% │
└─────────────────────────────────┴─────────────┴──────────┴───────┘

Accuracy (lm_eval, gsm8k, 5-shot):

┌──────────────────┬────────────────┬────────────────┬─────────────────────────────┐
│ Filter │ With Fusion │ Baseline │ Delta │
├──────────────────┼────────────────┼────────────────┼─────────────────────────────┤
│ flexible-extract │ 0.8605 ±0.0095 │ 0.8506 ±0.0098 │ +0.0099 (within error bars) │
│ strict-match │ 0.8089 ±0.0108 │ 0.8097 ±0.0108 │ −0.0008 (within error bars) │
└──────────────────┴────────────────┴────────────────┴─────────────────────────────┘

Accuracy is statistically identical — the fusion is numerically safe.

Test plan

• [x] Unit tests: pytest tests/compile/passes/test_fusion.py -k "gated" — positive and negative pattern match cases
• [x] lm_eval --tasks gsm8k --num_fewshot 5 — accuracy unchanged vs. baseline
• [x] vllm bench serve — throughput improved ~2.3%, TPOT improved ~2.5%
• [x] Verified graceful no-op when AITER lacks GDN kernels (are_gdn_triton_kernels_available() == False)


@claude claude Bot left a comment


Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@mergify mergify Bot added the rocm Related to AMD ROCm label Apr 23, 2026
@github-project-automation github-project-automation Bot moved this to Todo in AMD Apr 23, 2026
@mergify
Contributor

mergify Bot commented Apr 23, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tpopp.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 23, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces fusion support for RMSNormGated followed by FP8 group quantization on ROCm platforms using the aiter library. Key changes include the registration of a new fused custom operator, the implementation of a MatcherRMSNormGated class, and updates to the RocmAiterRMSNormQuantFusionPass to discover and fuse these patterns. Feedback focuses on critical safety issues regarding the global monkey-patching of the pattern matcher's type handling, which could lead to incorrect matches for other operators. Additionally, improvements were suggested to ensure the gated fusion pattern correctly supports both aiter and decomposed quantization variants and strictly validates the supported group size of 128 to prevent numerical errors.

Comment on lines +513 to +523
_orig_fx_to_pat = pm.fx_to_pattern

def _relaxed_fx_to_pattern(*a, **kw):
    kw["ignore_types"] = (int, torch.SymInt)
    return _orig_fx_to_pat(*a, **kw)

pm.fx_to_pattern = _relaxed_fx_to_pattern
try:
    self.matched_count = self.patterns.apply(graph)
finally:
    pm.fx_to_pattern = _orig_fx_to_pat
Contributor


critical

Monkey-patching pm.fx_to_pattern to ignore all int and torch.SymInt types is extremely dangerous. This change affects all patterns registered in self.patterns, including those that rely on specific integer arguments for correctness (e.g., group_size=128 in AiterRMSFp8GroupQuantPattern). If a graph contains a quantization op with a different group size (e.g., 64), the matcher will incorrectly identify it as a match, leading to a replacement with a fused op that uses the wrong group size. This will cause silent numerical errors. A more targeted approach to handle SymInt in reshapes should be used instead of a global type ignore.

Contributor Author


I haven't found a better approach, given the shortcomings of the PyTorch pattern-matching-based approach. This is becoming a common problem, especially when multiple reshapes exist.
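For reference, the try/finally restoration in the quoted snippet can be packaged as a scoped helper, so the override can never leak past the apply() call. This is a generic sketch, not code from this PR:

```python
from contextlib import contextmanager

@contextmanager
def patched_attr(obj, name, value):
    """Temporarily replace obj.<name>, restoring the original even if
    the body raises. Mirrors the try/finally pattern used to scope the
    fx_to_pattern override to a single patterns.apply() call."""
    original = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield original
    finally:
        setattr(obj, name, original)
```

Note that scoping the patch this way limits its lifetime but not its breadth: within the `with` block it still applies to every pattern matched, which is the reviewer's core concern.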

Comment thread vllm/compilation/passes/fusion/rocm_aiter_fusion.py
@tpopp tpopp force-pushed the tpopp/gdn-rmsnorm-quant-fusion branch from 5c39363 to 2c82404 Compare April 23, 2026 17:09
@mergify mergify Bot removed the needs-rebase label Apr 23, 2026
@tpopp tpopp force-pushed the tpopp/gdn-rmsnorm-quant-fusion branch from 2c82404 to d4f1b17 Compare May 4, 2026 06:53
@tpopp tpopp force-pushed the tpopp/gdn-rmsnorm-quant-fusion branch 3 times, most recently from 31da8cb to 7b6683e Compare May 4, 2026 14:08
@tpopp
Contributor Author

tpopp commented May 5, 2026

Some cleanup has been done; this now needs higher-level feedback and a ready label to enable more complete testing.

@dllehr-amd
Contributor

@gshtras Can you add the ready label for me?

@gshtras gshtras added the ready ONLY add when PR is ready to merge/full CI is needed label May 5, 2026
@tpopp tpopp force-pushed the tpopp/gdn-rmsnorm-quant-fusion branch from 5753895 to 3307453 Compare May 6, 2026 08:09
@tpopp
Contributor Author

tpopp commented May 6, 2026

@tjtanaa seems to be the relevant CODEOWNER for this PR.

@tpopp tpopp force-pushed the tpopp/gdn-rmsnorm-quant-fusion branch 2 times, most recently from f7fa464 to a786d20 Compare May 11, 2026 09:29
@mergify
Contributor

mergify Bot commented May 11, 2026

Hi @tpopp, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@tpopp tpopp force-pushed the tpopp/gdn-rmsnorm-quant-fusion branch 2 times, most recently from 1d1f946 to 232157b Compare May 11, 2026 11:38
@mergify
Contributor

mergify Bot commented May 11, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tpopp.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 11, 2026
@tpopp tpopp force-pushed the tpopp/gdn-rmsnorm-quant-fusion branch from 232157b to 6881651 Compare May 11, 2026 11:50
tpopp and others added 4 commits May 11, 2026 07:18
…y check

Register fused_rms_gated_fp8_group_quant custom op that wraps the
aiter Triton kernel for fused gated RMSNorm + FP8 group quantization.
Also add are_gdn_triton_kernels_available() to check whether the
required aiter Triton kernels (conv1d single-token, gated delta net)
are importable, allowing graceful fallback on older aiter versions.

Made-with: Cursor

Signed-off-by: Tres Popp <tres.popp@amd.com>
Implement pattern matching and replacement for decomposed RMSNormGated
followed by group FP8 quantization, fusing them into a single aiter
Triton kernel (fused_rms_gated_fp8_group_quant).

Key changes:
- Add AiterRMSNormGatedFp8GroupQuantPattern in rocm_aiter_fusion.py
  that matches the decomposed norm+reshape+quant graph and replaces it
  with the fused op
- Extend MatcherQuantFP8 and MatcherRMSNormGated in matcher_utils.py
  to support the gated norm pattern tracing
- Add forward_static to RMSNormGated for code sharing with the matcher
  and have forward_native delegate to it
- Simplify input_quant_fp8.py by extracting shared logic into
  forward_static
- Dynamically infer num_heads/head_dim from GatedDeltaNetAttention
  layers via static_forward_context
- Register per-token dynamic quant patterns for both aiter and
  non-aiter quant ops to handle +/- quant_fp8 configurations
- Gate the gated pattern on are_gdn_triton_kernels_available()
- Add unit tests for the fusion pattern (positive and negative cases)

Made-with: Cursor

Signed-off-by: Tres Popp <tres.popp@amd.com>
- Remove unused MatcherFusedAddRMSNorm and its dead imports (RMSNorm,
  RMS_ADD_OP)
- Move fold_consecutive_reshapes to vllm_inductor_pass.py next to the
  related _fx_view_to_reshape helper
- Add docstrings to new _aiter_ops methods
  (fused_rms_gated_fp8_group_quant impl and getter)
- Check fused_rms_gated_fp8_group_quant importability in
  are_gdn_triton_kernels_available
- Restore docstring on RMSNormGated.forward_native

Signed-off-by: Tres Popp <trespopp@gmail.com>
Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Tres Popp <tres.popp@amd.com>
Iterate over use_triton for group quant patterns so both the CK and
triton backends are matched.  Use a set to deduplicate when quant_fp8
is disabled (forward_native is identical for both use_triton values).
Add a head_dim == 128 guard to AiterRMSNormGatedFp8GroupQuantPattern
since the fused kernel hardcodes group_size=head_dim.

Rename _fx_view_to_reshape to fx_view_to_reshape as it is not private.

Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Tres Popp <tres.popp@amd.com>
@tpopp tpopp force-pushed the tpopp/gdn-rmsnorm-quant-fusion branch from 6881651 to fe5b13f Compare May 11, 2026 12:21
@mergify mergify Bot removed the needs-rebase label May 11, 2026
@mergify
Contributor

mergify Bot commented May 11, 2026

Hi @tpopp, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@tpopp tpopp force-pushed the tpopp/gdn-rmsnorm-quant-fusion branch from fe5b13f to 141e8c5 Compare May 11, 2026 12:30
The triton vs CK group quant op selection was added speculatively but
the approved PR vllm-project#41825 uses only the aiter (CK) group quant op.  Align
the pattern matching with that decision.

Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
@tpopp tpopp force-pushed the tpopp/gdn-rmsnorm-quant-fusion branch from 141e8c5 to 380c7ce Compare May 11, 2026 12:31
Revert the fx_view_to_reshape rename since the function already exists
upstream with the underscore prefix. Only apply the ignore_types
monkey-patch for the pattern matcher when gated norm patterns are
actually registered, avoiding interference with existing per-token
and per-tensor fusion patterns.

Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
@mergify
Contributor

mergify Bot commented May 11, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tpopp.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 11, 2026
The gated RMSNorm + group FP8 quant pattern matches when the quant op
traces through native code (-quant_fp8) rather than the custom op.
Remove ops_in_model_before since the pre-fusion quant op depends on the
custom_ops config.

Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
@tpopp tpopp force-pushed the tpopp/gdn-rmsnorm-quant-fusion branch from 4ecce31 to 20e4933 Compare May 11, 2026 14:43
These parameters were unintentionally removed during earlier cleanup.
They are needed by the existing non-gated pattern registration logic.

Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
@tpopp
Contributor Author

tpopp commented May 11, 2026

@tjtanaa Do you mind taking a look?

  • The current conflict is with a similar change that was made. For now, I've essentially copied the relevant change around quant_fp8 matching for MI35X.
  • I've cleared up the worst of the hacks done to work around type matching issues. The tests are verbose; I don't know if that's a problem.

I have a plan for a vllm_ir related form of rms_norm_gated but would like to do that as a follow up, so it's precisely targeted, and so I can separately clarify some details over how the IR and pattern matching behaves when custom ops are enabled.

@tpopp
Contributor Author

tpopp commented May 11, 2026

I also haven't seen a better way to handle the shape-related data gathering. As far as I can tell, the pattern matcher's reliance on constructing the same ops forces us to build patterns with the exact derived constants, but pointers are welcome if I'm wrong.

tpopp and others added 2 commits May 11, 2026 14:19
Pass match_aiter_quant through to super().__init__ instead of creating
a separate MatcherQuantFP8. The base class already creates the matcher
with the correct quant key.

Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
…-fusion

Co-authored-by: Cursor <cursoragent@cursor.com>

# Conflicts:
#	vllm/compilation/passes/fusion/matcher_utils.py
#	vllm/compilation/passes/fusion/rocm_aiter_fusion.py

Signed-off-by: Tres Popp <tres.popp@amd.com>
@mergify mergify Bot removed the needs-rebase label May 11, 2026

Labels

ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm

Projects

Status: Todo

Development

Successfully merging this pull request may close these issues.

3 participants